[Offload] Erase entries from JIT cache when program is destroyed #148847

Merged
RossBrunton merged 2 commits into llvm:main on Jul 25, 2025

Conversation

RossBrunton
Contributor

When `unloadBinary` is called, any entries in the JITEngine's cache
for that binary are cleared. This fixes a nasty issue with
liboffload program handles: if two handles happened to have the same
address (after one was freed, for example), the cache would be hit and
return the wrong program.

@llvmbot
Member

llvmbot commented Jul 15, 2025

@llvm/pr-subscribers-offload

Author: Ross Brunton (RossBrunton)

Changes


Full diff: https://github.com/llvm/llvm-project/pull/148847.diff

3 Files Affected:

  • (modified) offload/plugins-nextgen/common/include/JIT.h (+8-2)
  • (modified) offload/plugins-nextgen/common/src/JIT.cpp (+17-9)
  • (modified) offload/plugins-nextgen/common/src/PluginInterface.cpp (+3)
diff --git a/offload/plugins-nextgen/common/include/JIT.h b/offload/plugins-nextgen/common/include/JIT.h
index 8c530436a754b..68a6d039c4641 100644
--- a/offload/plugins-nextgen/common/include/JIT.h
+++ b/offload/plugins-nextgen/common/include/JIT.h
@@ -55,6 +55,10 @@ struct JITEngine {
   process(const __tgt_device_image &Image,
           target::plugin::GenericDeviceTy &Device);
 
+  /// Remove \p Image from the jit engine's cache
+  void erase(const __tgt_device_image &Image,
+             target::plugin::GenericDeviceTy &Device);
+
 private:
   /// Compile the bitcode image \p Image and generate the binary image that can
   /// be loaded to the target device of the triple \p Triple architecture \p
@@ -90,10 +94,12 @@ struct JITEngine {
     LLVMContext Context;
 
     /// Output images generated from LLVM backend.
-    SmallVector<std::unique_ptr<MemoryBuffer>, 4> JITImages;
+    DenseMap<const __tgt_device_image *, std::unique_ptr<MemoryBuffer>>
+        JITImages;
 
     /// A map of embedded IR images to JITed images.
-    DenseMap<const __tgt_device_image *, __tgt_device_image *> TgtImageMap;
+    DenseMap<const __tgt_device_image *, std::unique_ptr<__tgt_device_image>>
+        TgtImageMap;
   };
 
   /// Map from (march) "CPUs" (e.g., sm_80, or gfx90a), which we call compute
diff --git a/offload/plugins-nextgen/common/src/JIT.cpp b/offload/plugins-nextgen/common/src/JIT.cpp
index c82a06e36d8f9..00720fa2d8103 100644
--- a/offload/plugins-nextgen/common/src/JIT.cpp
+++ b/offload/plugins-nextgen/common/src/JIT.cpp
@@ -285,8 +285,8 @@ JITEngine::compile(const __tgt_device_image &Image,
 
   // Check if we JITed this image for the given compute unit kind before.
   ComputeUnitInfo &CUI = ComputeUnitMap[ComputeUnitKind];
-  if (__tgt_device_image *JITedImage = CUI.TgtImageMap.lookup(&Image))
-    return JITedImage;
+  if (CUI.TgtImageMap.contains(&Image))
+    return CUI.TgtImageMap[&Image].get();
 
   auto ObjMBOrErr = getOrCreateObjFile(Image, CUI.Context, ComputeUnitKind);
   if (!ObjMBOrErr)
@@ -296,17 +296,15 @@ JITEngine::compile(const __tgt_device_image &Image,
   if (!ImageMBOrErr)
     return ImageMBOrErr.takeError();
 
-  CUI.JITImages.push_back(std::move(*ImageMBOrErr));
-  __tgt_device_image *&JITedImage = CUI.TgtImageMap[&Image];
-  JITedImage = new __tgt_device_image();
+  CUI.JITImages.insert({&Image, std::move(*ImageMBOrErr)});
+  auto &ImageMB = CUI.JITImages[&Image];
+  CUI.TgtImageMap.insert({&Image, std::make_unique<__tgt_device_image>()});
+  auto &JITedImage = CUI.TgtImageMap[&Image];
   *JITedImage = Image;
-
-  auto &ImageMB = CUI.JITImages.back();
-
   JITedImage->ImageStart = const_cast<char *>(ImageMB->getBufferStart());
   JITedImage->ImageEnd = const_cast<char *>(ImageMB->getBufferEnd());
 
-  return JITedImage;
+  return JITedImage.get();
 }
 
 Expected<const __tgt_device_image *>
@@ -324,3 +322,13 @@ JITEngine::process(const __tgt_device_image &Image,
 
   return &Image;
 }
+
+void JITEngine::erase(const __tgt_device_image &Image,
+                      target::plugin::GenericDeviceTy &Device) {
+  std::lock_guard<std::mutex> Lock(ComputeUnitMapMutex);
+  const std::string &ComputeUnitKind = Device.getComputeUnitKind();
+  ComputeUnitInfo &CUI = ComputeUnitMap[ComputeUnitKind];
+
+  CUI.TgtImageMap.erase(&Image);
+  CUI.JITImages.erase(&Image);
+}
diff --git a/offload/plugins-nextgen/common/src/PluginInterface.cpp b/offload/plugins-nextgen/common/src/PluginInterface.cpp
index 81b9d423e13d8..94a050b559efe 100644
--- a/offload/plugins-nextgen/common/src/PluginInterface.cpp
+++ b/offload/plugins-nextgen/common/src/PluginInterface.cpp
@@ -854,6 +854,9 @@ Error GenericDeviceTy::unloadBinary(DeviceImageTy *Image) {
       return Err;
   }
 
+  if (Image->getTgtImageBitcode())
+    Plugin.getJIT().erase(*Image->getTgtImageBitcode(), Image->getDevice());
+
   return unloadBinaryImpl(Image);
 }
 

@@ -854,6 +854,9 @@ Error GenericDeviceTy::unloadBinary(DeviceImageTy *Image) {
       return Err;
   }
 
+  if (Image->getTgtImageBitcode())
Contributor

I forget how this works: we have the image passed in by the user and the one created by the backend, right? I'm wondering if we should just check the magic bytes there.

Contributor Author

If we are compiling bitcode, then this will be set to a __tgt_device_image * for the original bitcode that olCreateProgram copied. It is nullptr if no compilation took place (and so we don't need to tell the JIT to remove anything).
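To make the guard concrete, here is a minimal sketch of the same shape outside the plugin code. All of the names (`BitcodeImage`, `LoadedImage`, `JitCache`, `unload`) are hypothetical stand-ins rather than the offload types, and a `std::set` stands in for the real cache.

```cpp
#include <cassert>
#include <set>

// Hypothetical stand-ins for illustration; these are not the offload types.
struct BitcodeImage {};

struct JitCache {
  // Addresses of the bitcode images that currently have a cached JIT result.
  std::set<const BitcodeImage *> Entries;
  void erase(const BitcodeImage &In) { Entries.erase(&In); }
};

struct LoadedImage {
  // Null when the binary was loaded directly, with no JIT compilation.
  const BitcodeImage *Bitcode = nullptr;
};

// Same shape as the guard discussed above: only talk to the JIT cache when
// the image was produced from bitcode.
void unload(JitCache &Cache, const LoadedImage &Img) {
  if (Img.Bitcode) {
    Cache.erase(*Img.Bitcode);
    assert(Cache.Entries.count(Img.Bitcode) == 0);
  }
  // Backend-specific unloading would follow here.
}
```

Unloading an image whose `Bitcode` pointer is null leaves the cache untouched, which matches the "no compilation took place" case described above.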

@@ -285,8 +285,8 @@ JITEngine::compile(const __tgt_device_image &Image,
 
   // Check if we JITed this image for the given compute unit kind before.
   ComputeUnitInfo &CUI = ComputeUnitMap[ComputeUnitKind];
-  if (__tgt_device_image *JITedImage = CUI.TgtImageMap.lookup(&Image))
Contributor

I don't follow: in what cases is this Image not a unique one?

Contributor Author

The data Image points to happens to be inside ol_program_impl_t, but the problem is similar to this:

Image *MyImage = new Image();
delete MyImage;
Image *MyImage2 = new Image();
// MyImage2 may be given the same address that MyImage had
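A self-contained sketch of the hazard, under stated assumptions: `Image`, `JitCache` and `compileSecondImage` are hypothetical names, `std::unordered_map` stands in for `DenseMap`, and placement new into one buffer is used only to force the address reuse that an allocator might otherwise produce by chance.

```cpp
#include <cassert>
#include <memory>
#include <new>
#include <string>
#include <unordered_map>

// Hypothetical stand-ins; none of these names come from the offload code.
struct Image {
  std::string Bitcode;
};

struct JitCache {
  // Keyed by the address of the input image, like the cache in this PR.
  std::unordered_map<const Image *, std::unique_ptr<std::string>> Compiled;

  const std::string &compile(const Image &In) {
    auto It = Compiled.find(&In);
    if (It != Compiled.end())
      return *It->second; // Cache hit: the address alone is trusted.
    auto Out = std::make_unique<std::string>("jit(" + In.Bitcode + ")");
    return *Compiled.emplace(&In, std::move(Out)).first->second;
  }

  void erase(const Image &In) { Compiled.erase(&In); }
};

// Builds two images one after the other in the same storage, so the second
// is guaranteed to get the address the first one had. Returns what the cache
// hands back for the second image.
std::string compileSecondImage(bool EraseOnDestroy) {
  JitCache Cache;
  alignas(Image) unsigned char Storage[sizeof(Image)];

  Image *First = new (Storage) Image{"A"};
  const void *FirstAddr = First;
  Cache.compile(*First);
  if (EraseOnDestroy)
    Cache.erase(*First);
  First->~Image();

  Image *Second = new (Storage) Image{"B"};
  assert(static_cast<const void *>(Second) == FirstAddr);
  std::string Result = Cache.compile(*Second);
  Second->~Image();
  return Result;
}
```

With `EraseOnDestroy` set to `false`, the second image gets the result compiled for the first; with `true`, the stale entry is gone and the second image is compiled from its own contents.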

Contributor

So the image can be created on the fly. In that case, we probably still want to cache that, but use a different key that can effectively tell the two images apart.

Contributor Author

The JIT engine uses the Image address to identify input images; once the Image is dropped, we lose the ability to look it up in the cache, so there's no reason to keep the entry around. What use case are you thinking of?

Contributor

@shiltian Jul 15, 2025

But your case demonstrates that the image address might not be unique, even for the "same" image.

Contributor Author

As long as the image pointer is actually alive, it's a unique identifier for the JIT'ed binary. The issue only happens once the input image is freed.

Contributor Author

Re-reading this thread, I think I worded things a bit confusingly; by "Image address" I mean &Image rather than Image->ImageStart.

I'm not sure there's a key that we can use to uniquely identify Images across create/destroy boundaries, nor can I see a use case for that.

Contributor

A SHA value based on the contents of the image could do it. That way, even if the image "handle" is created and destroyed multiple times, the contents of the image are expected to be the same.

I'll not be the blocker here. This is less ideal but I'm fine with it.
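For comparison, a rough sketch of the content-keyed alternative suggested here. This is not what the PR implements. `ContentKeyedCache` and `compilationsForRecreatedImage` are hypothetical names, and `std::hash` is a stand-in for a SHA digest: it is not cryptographic and can collide, so a real cache would want a proper digest or a full content comparison on a hit.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <string>
#include <string_view>
#include <unordered_map>

// Hypothetical content-keyed cache; std::hash stands in for a SHA digest.
struct ContentKeyedCache {
  std::unordered_map<std::size_t, std::string> Compiled;
  int Compilations = 0;

  const std::string &compile(std::string_view Bitcode) {
    assert(!Bitcode.empty() && "expected a non-empty image");
    std::size_t Key = std::hash<std::string_view>{}(Bitcode);
    auto It = Compiled.find(Key);
    if (It != Compiled.end())
      return It->second; // Hit depends on contents, not on any address.
    ++Compilations;
    std::string Out = "jit(" + std::string(Bitcode) + ")";
    return Compiled.emplace(Key, std::move(Out)).first->second;
  }
};

// Creates, destroys, and recreates an image with identical contents, then
// reports how many times the cache actually had to compile.
int compilationsForRecreatedImage() {
  ContentKeyedCache Cache;
  {
    std::string First = "same bitcode";
    Cache.compile(First);
  } // First is destroyed here; its cache entry stays.
  std::string Second = "same bitcode";
  Cache.compile(Second);
  return Cache.Compilations;
}
```

The trade-off is that entries now outlive every handle, so such a cache needs its own eviction policy, and each lookup pays for hashing the whole image.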

Contributor

We need to seriously rework the image handling as a whole.

Contributor Author

I'm still unsure of why we need to keep them around after the backing image has been dropped. I can't see the user constantly recreating the same buffer many times with the same contents and expecting high performance.

@RossBrunton
Contributor Author

@shiltian @jhuber6 Can I get this looked at again?

@RossBrunton RossBrunton merged commit ae44418 into llvm:main Jul 25, 2025
9 checks passed